package com.example.security;

import com.example.config.RateLimitService;
import com.example.util.SecurityUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.servlet.*;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.io.IOException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;

/**
 * DDoS/CC攻击防护过滤器
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class DDoSProtectionFilter implements Filter {

    private final RateLimitService rateLimitService;
    private final RedisTemplate<String, Object> redisTemplate;
    private final SecurityUtils securityUtils;
    private final ObjectMapper objectMapper;

    @Value("${security.ddos.enabled:true}")
    private boolean ddosProtectionEnabled;

    @Value("${security.ddos.ip-request-limit:100}")
    private int ipRequestLimit;

    @Value("${security.ddos.ip-time-window:60}")
    private int ipTimeWindow;

    @Value("${security.ddos.global-request-limit:10000}")
    private int globalRequestLimit;

    @Value("${security.ddos.global-time-window:60}")
    private int globalTimeWindow;

    @Value("${security.ddos.suspicious-threshold:50}")
    private int suspiciousThreshold;

    @Value("${security.ddos.block-duration:300}")
    private int blockDuration;

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {

        if (!ddosProtectionEnabled) {
            chain.doFilter(request, response);
            return;
        }

        HttpServletRequest httpRequest = (HttpServletRequest) request;
        HttpServletResponse httpResponse = (HttpServletResponse) response;

        String clientIp = getClientIp(httpRequest);
        String userAgent = httpRequest.getHeader("User-Agent");
        String requestUri = httpRequest.getRequestURI();

        try {
            // 检查IP是否被封禁
            if (isIpBlocked(clientIp)) {
                log.warn("被封禁IP尝试访问: ip={}, uri={}", clientIp, requestUri);
                sendBlockedResponse(httpResponse, "IP_BLOCKED", "IP已被封禁");
                return;
            }

            // 检查是否为恶意请求
            if (isMaliciousRequest(httpRequest)) {
                log.warn("检测到恶意请求: ip={}, uri={}, userAgent={}", clientIp, requestUri, userAgent);
                blockIp(clientIp, "恶意请求");
                sendBlockedResponse(httpResponse, "MALICIOUS_REQUEST", "恶意请求");
                return;
            }

            // IP级别限流
            if (!rateLimitService.isAllowed("ddos_ip:" + clientIp, ipRequestLimit, ipTimeWindow)) {
                log.warn("IP请求频率过高: ip={}, uri={}", clientIp, requestUri);
                
                // 记录可疑行为
                recordSuspiciousActivity(clientIp, "HIGH_FREQUENCY_REQUESTS");
                
                // 检查是否需要封禁
                if (shouldBlockIp(clientIp)) {
                    blockIp(clientIp, "高频请求");
                }
                
                sendBlockedResponse(httpResponse, "RATE_LIMIT_EXCEEDED", "请求过于频繁");
                return;
            }

            // 全局限流
            if (!rateLimitService.isAllowed("ddos_global", globalRequestLimit, globalTimeWindow)) {
                log.warn("全局请求频率过高: ip={}, uri={}", clientIp, requestUri);
                sendBlockedResponse(httpResponse, "GLOBAL_RATE_LIMIT", "系统繁忙，请稍后再试");
                return;
            }

            // 检查请求模式
            if (isAbnormalRequestPattern(httpRequest)) {
                log.warn("检测到异常请求模式: ip={}, uri={}", clientIp, requestUri);
                recordSuspiciousActivity(clientIp, "ABNORMAL_PATTERN");
            }

            // 记录正常请求
            recordRequest(clientIp, requestUri, userAgent);

            // 继续处理请求
            chain.doFilter(request, response);

        } catch (Exception e) {
            log.error("DDoS防护过滤器异常: {}", e.getMessage(), e);
            chain.doFilter(request, response);
        }
    }

    /**
     * 检查IP是否被封禁
     */
    private boolean isIpBlocked(String clientIp) {
        try {
            String blockKey = "ddos_blocked_ip:" + clientIp;
            return redisTemplate.hasKey(blockKey);
        } catch (Exception e) {
            log.error("检查IP封禁状态异常: {}", e.getMessage());
            return false;
        }
    }

    /**
     * 封禁IP
     */
    private void blockIp(String clientIp, String reason) {
        try {
            String blockKey = "ddos_blocked_ip:" + clientIp;
            Map<String, Object> blockInfo = new HashMap<>();
            blockInfo.put("reason", reason);
            blockInfo.put("blockedAt", LocalDateTime.now().toString());
            blockInfo.put("blockedBy", "DDoS_PROTECTION");
            
            redisTemplate.opsForValue().set(blockKey, blockInfo, blockDuration, TimeUnit.SECONDS);
            
            log.warn("IP已被封禁: ip={}, reason={}, duration={}秒", clientIp, reason, blockDuration);
            
            // 记录封禁事件
            recordSecurityEvent("IP_BLOCKED", clientIp, reason);
            
        } catch (Exception e) {
            log.error("封禁IP异常: {}", e.getMessage(), e);
        }
    }

    /**
     * 检查是否为恶意请求
     */
    private boolean isMaliciousRequest(HttpServletRequest request) {
        String userAgent = request.getHeader("User-Agent");
        String requestUri = request.getRequestURI();
        String queryString = request.getQueryString();

        // 检查User-Agent
        if (userAgent == null || userAgent.trim().isEmpty()) {
            return true;
        }

        // 检查已知的恶意User-Agent
        String[] maliciousUserAgents = {
            "sqlmap", "nmap", "nikto", "dirb", "gobuster", "masscan",
            "python-requests", "curl", "wget", "httperf", "ab"
        };
        
        for (String malicious : maliciousUserAgents) {
            if (userAgent.toLowerCase().contains(malicious)) {
                return true;
            }
        }

        // 检查SQL注入攻击
        if (securityUtils.containsSqlInjection(requestUri) || 
            (queryString != null && securityUtils.containsSqlInjection(queryString))) {
            return true;
        }

        // 检查XSS攻击
        if (securityUtils.containsXss(requestUri) || 
            (queryString != null && securityUtils.containsXss(queryString))) {
            return true;
        }

        // 检查路径遍历攻击
        if (securityUtils.containsPathTraversal(requestUri)) {
            return true;
        }

        // 检查命令注入
        if (containsCommandInjection(requestUri) || 
            (queryString != null && containsCommandInjection(queryString))) {
            return true;
        }

        return false;
    }

    /**
     * 检查命令注入
     */
    private boolean containsCommandInjection(String input) {
        if (input == null) return false;
        
        String[] commandPatterns = {
            ";", "|", "&", "$", "`", "$(", "&&", "||", 
            "exec", "system", "shell_exec", "passthru",
            "eval", "base64_decode", "file_get_contents"
        };
        
        String lowerInput = input.toLowerCase();
        for (String pattern : commandPatterns) {
            if (lowerInput.contains(pattern)) {
                return true;
            }
        }
        
        return false;
    }

    /**
     * 检查异常请求模式
     */
    private boolean isAbnormalRequestPattern(HttpServletRequest request) {
        String requestUri = request.getRequestURI();
        String method = request.getMethod();
        
        // 检查异常长的URI
        if (requestUri.length() > 2048) {
            return true;
        }
        
        // 检查异常的HTTP方法
        if (!method.matches("GET|POST|PUT|DELETE|PATCH|OPTIONS|HEAD")) {
            return true;
        }
        
        // 检查异常的文件扩展名请求
        String[] suspiciousExtensions = {
            ".php", ".asp", ".jsp", ".cgi", ".pl", ".py", ".sh", ".bat", ".exe"
        };
        
        for (String ext : suspiciousExtensions) {
            if (requestUri.toLowerCase().endsWith(ext)) {
                return true;
            }
        }
        
        return false;
    }

    /**
     * 记录可疑活动
     */
    private void recordSuspiciousActivity(String clientIp, String activityType) {
        try {
            String key = "ddos_suspicious:" + clientIp;
            redisTemplate.opsForHash().increment(key, activityType, 1);
            redisTemplate.expire(key, 3600, TimeUnit.SECONDS); // 1小时过期
        } catch (Exception e) {
            log.error("记录可疑活动异常: {}", e.getMessage());
        }
    }

    /**
     * 检查是否应该封禁IP
     */
    private boolean shouldBlockIp(String clientIp) {
        try {
            String key = "ddos_suspicious:" + clientIp;
            Map<Object, Object> activities = redisTemplate.opsForHash().entries(key);
            
            int totalSuspiciousCount = activities.values().stream()
                    .mapToInt(v -> Integer.parseInt(v.toString()))
                    .sum();
            
            return totalSuspiciousCount >= suspiciousThreshold;
        } catch (Exception e) {
            log.error("检查IP封禁条件异常: {}", e.getMessage());
            return false;
        }
    }

    /**
     * 记录请求
     */
    private void recordRequest(String clientIp, String requestUri, String userAgent) {
        try {
            // 记录请求统计
            String statsKey = "ddos_stats:" + LocalDateTime.now().getHour();
            redisTemplate.opsForHash().increment(statsKey, "total_requests", 1);
            redisTemplate.opsForHash().increment(statsKey, "ip:" + clientIp, 1);
            redisTemplate.expire(statsKey, 3600, TimeUnit.SECONDS);
        } catch (Exception e) {
            log.debug("记录请求统计异常: {}", e.getMessage());
        }
    }

    /**
     * 记录安全事件
     */
    private void recordSecurityEvent(String eventType, String clientIp, String details) {
        try {
            Map<String, Object> event = new HashMap<>();
            event.put("type", eventType);
            event.put("ip", clientIp);
            event.put("details", details);
            event.put("timestamp", LocalDateTime.now().toString());
            
            String eventKey = "security_event:" + System.currentTimeMillis();
            redisTemplate.opsForValue().set(eventKey, event, 7, TimeUnit.DAYS);
        } catch (Exception e) {
            log.error("记录安全事件异常: {}", e.getMessage());
        }
    }

    /**
     * 发送封禁响应
     */
    private void sendBlockedResponse(HttpServletResponse response, String errorCode, String message) 
            throws IOException {
        response.setStatus(429); // 429 Too Many Requests
        response.setContentType("application/json;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");
        
        Map<String, Object> errorResponse = new HashMap<>();
        errorResponse.put("success", false);
        errorResponse.put("error", errorCode);
        errorResponse.put("message", message);
        errorResponse.put("timestamp", LocalDateTime.now().toString());
        
        response.getWriter().write(objectMapper.writeValueAsString(errorResponse));
    }

    /**
     * 获取客户端IP
     */
    private String getClientIp(HttpServletRequest request) {
        String ip = request.getHeader("X-Forwarded-For");
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_CLIENT_IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("HTTP_X_FORWARDED_FOR");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        
        if (ip != null && ip.contains(",")) {
            ip = ip.split(",")[0].trim();
        }
        
        return ip != null ? ip : "unknown";
    }
}
